{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d71afc97",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Module Doc String\n",
    "\"\"\"\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import ticker\n",
    "import seaborn as sns\n",
    "from fancy_einsum import einsum\n",
    "from transformer_lens import HookedTransformer\n",
    "from toxicity.figures.fig_utils import convert, load_hooked\n",
    "from constants import ROOT_DIR, MODEL_DIR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e0137eae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model gpt2-medium into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "\n",
    "model = load_hooked(\"gpt2-medium\", os.path.join(MODEL_DIR, \"dpo.pt\"))\n",
    "model.tokenizer.padding_side = \"left\"\n",
    "model.tokenizer.pad_token_id = model.tokenizer.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f744b3b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "prompts = list(\n",
    "    np.load(os.path.join(ROOT_DIR, \"toxicity/figures/shit_prompts.npy\"))\n",
    ")\n",
    "tokens = model.to_tokens(prompts, prepend_bos=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ed241c56",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 73/73 [00:17<00:00,  4.29it/s]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "batchsize = 4\n",
    "all_dpo_prob = None\n",
    "all_gpt2_prob = None\n",
    "for idx in tqdm(range(0, tokens.shape[0], batchsize)):\n",
    "    batch = tokens[idx : idx + batchsize].cuda()\n",
    "    with torch.inference_mode():\n",
    "        _, cache = model.run_with_cache(batch)\n",
    "\n",
    "        accum = cache.accumulated_resid(layer=-1, incl_mid=True, apply_ln=True)\n",
    "\n",
    "        # Project each layer and each position onto vocab space\n",
    "        vocab_proj = einsum(\n",
    "            \"layer batch pos d_model, d_model d_vocab --> layer batch pos d_vocab\",\n",
    "            accum,\n",
    "            model.W_U,\n",
    "        )\n",
    "\n",
    "    shit_probs = vocab_proj.softmax(dim=-1)[:, :, -1, 7510].cpu()\n",
    "    if all_dpo_prob is None:\n",
    "        all_dpo_prob = shit_probs\n",
    "    else:\n",
    "        all_dpo_prob = torch.concat([all_dpo_prob, shit_probs], dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "61a1f967",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model gpt2-medium into HookedTransformer\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 73/73 [00:16<00:00,  4.37it/s]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "model = HookedTransformer.from_pretrained(\"gpt2-medium\")\n",
    "model.tokenizer.padding_side = \"left\"\n",
    "model.tokenizer.pad_token_id = model.tokenizer.eos_token_id\n",
    "\n",
    "for idx in tqdm(range(0, tokens.shape[0], batchsize)):\n",
    "    batch = tokens[idx : idx + batchsize].cuda()\n",
    "    with torch.inference_mode():\n",
    "        _, cache = model.run_with_cache(batch)\n",
    "\n",
    "        accum, accum_labels = cache.accumulated_resid(\n",
    "            layer=-1, incl_mid=True, apply_ln=True, return_labels=True\n",
    "        )\n",
    "        vocab_proj = einsum(\n",
    "            \"layer batch pos d_model, d_model d_vocab --> layer batch pos d_vocab\",\n",
    "            accum,\n",
    "            model.W_U,\n",
    "        )\n",
    "\n",
    "    shit_probs = vocab_proj.softmax(dim=-1)[:, :, -1, 7510].cpu()\n",
    "    if all_gpt2_prob is None:\n",
    "        all_gpt2_prob = shit_probs\n",
    "    else:\n",
    "        all_gpt2_prob = torch.concat([all_gpt2_prob, shit_probs], dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ae18867b",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAATgAAACcCAYAAADrhbcmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/OQEPoAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAwD0lEQVR4nO3deXgUVbrA4V/1mqSzdvbFhD0sssgFWZVV9CIYBkVARhxkVFREdATEdUBAjaIoXBFnYBQdkTjCgAsMIIuEUZRNhQCBQAxkJ+nsnd6q7h8NLZEkdIckhPa8z8NjuvrUqa/azpc6VWeRFEVREARB8EKqqx2AIAhCUxEJThAEryUSnCAIXkskOEEQvJZIcIIgeC2R4ARB8FoiwQmC4LVEghMEwWuJBCcIgtcSCU4QBK8lEpwgCF5LJDhBELyWSHCCIHgtkeAEQfBaIsG1AImJiXTp0oXi4uJL3rv33ntJTEzk7NmzHtdbXFzs9r5Lly7loYce8vgYgtCSiQTXQgQEBPDll1/W2JaTk8PRo0evUkSCcO3z+gRnNps5cuQIZrP5aodSr5EjR7Jx48Ya2zZs2MCIESNqbDt79izTp0+nT58+DBo0iJdffpnq6moAZFnmzTffpG/fvvTr1481a9bU2DcvL4/HHnuMvn37MnToUN555x0cDkfTnpggXEVen+BOnTrF2LFjOXXq1NUOpV4jRozgxIkTZGZmurZt2LCBP/zhD67XVquVKVOmEBYWxs6dO0lJSeHQoUMsWrQIgLVr17Jx40bWrFnD1q1ba1z9ORwOpk2bRmRkJDt37mT16tVs2rSJjz76qNnOUWh81dXVlJaW1vh34Q9efeXcKVNXuWuJ1ye4a4XBYGDYsGGuq7gff/wRg8FA27ZtXWX2799PYWEhzzzzDL6+vkRGRjJr1iz+/e9/I8syX375JRMnTqR169b4+/sza9Ys176HDx/m9OnTzJkzBx8fH+Li4pg2bRopKSnNfq5C47FYLJw5c4asrCyysrI4c+YMFoul3nLulKmvXGOx2eUmq/sCTZMfQXDbHXfcwcKFC5kxYwb//ve/a1y9ARQVFREeHo5Op3Nti4uLw2KxUFRURGFhIVFRUa73YmNjXT9nZ2djtVrp16+fa5uiKEiS1IRnJDSHi5dVqW+JlQvvuVPmcuWulKm8mrP5FSQmhKDTqpvsOCLBtSADBw6ksrKSffv2sXXrVmbMmFHjSxYdHU1hYSFWq9WV5LKystBqtQQFBREREUFOTo6rfEFBgevnyMhIgoOD+fbbb13bSktLKS8vb4YzE4RfVZptnDpbis3uwFRuIdLo12THEk3UFkStVnP77bfz17/+lR49ehASElLj/W7duhEbG8uiRYswm83k5+fz+uuvc/vtt6PT6Rg7diwfffQR6enpmM1mFi9eXGPf0NBQ3nrrLSwWCyUlJcycOZOFCxc292kKv2MWm4OMsyVoNSoCDHryi6uQ5aa7UhQJroVJSkrixIkTjBkz5pL3tFot7777Lvn5+QwePJikpCS6du3Kiy++CMAf/vAHJk+ezJQpUxg0aFCN+3darZYVK1Zw/PhxBg0axK233orRaOSVV15prlMTfuccDpnT2SXYHDIWq4Nqq51qq53yKmuTHVPy9nVRjxw5wtixY1m3bh1dunS52uEIQqMqLS0lKyurxrb4+HiCgoLqLedOmbrKNYQsK/ySV0ZRqZmyCivL/vUjPjo1Myf0JCRAT/v4kMtX0gDiCk4QhCYlywq5RZUUmKoAWPX5EWx2mfIqG2mnizBVWDBb7E1ybJHgBEFoMnaHTGZuGWcLyjHotbz/RRqllVa0Gmfq2X0oGwkoLm2ajvgiwQmC0CSqrXZOnDFRVGom2F/P+l0n+SWvHD+9hhl398BHp6bAZCa7oIICkxm7o/H7xV2VBLd7926SkpK47bbbmDJlSo3uDL9ltVoZN24c77zzTjNGKAjClSivsnI0s5hqi4OQAB9SD+XwfVo+kgSTR3YiLiKAPl2cfTb3/JSD3SFTUt74nYqbPcEVFxfz1FNPkZyczObNmxkyZAhz586ts/z8+fM5c+ZMM0YoCEJDybJCfnEVxzJNaFQqAvx0HDlVxIbdGQDccVNbEhOMAPTvGoMkwbFfTJRWWMk3VTV652KPO/p+++23NXrDeyo1NZXExEQSExMBmDBhAsnJyRQWFhIeHl6jbEpKClarlcGDB1+23oKCAgoLCy/ZnpGR0eBYBUFwX1W1jay8MsoqbQT66XAoCut2nmT3oWwAeneKZNANztE1ZVVWVCqJzq1DOXKqiO/T8hh+YzyVZhv+frr6DuMRjxPc7NmzUavVjB49mqSkJNq1a+fR/nl5eURHR7te63Q6QkJCyM3NrZHgfvrpJ1JSUvjwww+ZN2/eZetdu3Yty5Yt8ygWQRCunENWKDBVcbagHK1KhTHQh4zsEj7Zcpxzpc7B+v2uj+YPg9shSRJV1TbUKomYCH/+p2MER04Vse9oPjf1iKWs0np1E9yuXbv49ttv+fzzzxk/fjwJCQmMGTOGUaNGYTQaL7t/XeMfVapfW8vFxcU8++yzLFu2DF9fX7fiGj9+PEOHDr1ke0ZGRo1B54IgNA6HrFBWaSH3XCUVZudVmywrrN91kt0Hs1GA4AA944d3oOP5ZqnF5sBml+nYyohWo6J1TCAxYQZyzlVy8HgBraIDGzVGjxOcSqViwIABDBgwgHnz5rFr1y6WL19OcnIyAwcOZOLEiQwaNKjO/WNiYvjuu+9cr61WKyaTiZiYGNe2r7/+msrKSmbMmAFAbm4uOp2O8vJy5syZU2u9ERERREREeHo6giB4yGZ3UFxmIa+oEovVjo9OgzHAhxNnTKzdlk7R+au2vtdHccdNbfHVO9OM3S5TabbRIT4Eg68WgMhQA707R7Hhmwz2HsnjjpvaNGqsDR5s/+OPP/LFF1+wefNmNBoN999/P7Gxsbz66qts3769zmblgAEDWLBgAenp6XTo0IFPP/2U7t2717j6GzduHOPGjXO9fvrpp4mPj+eRRx5paLi/Oxs2bGDNmjWYTCYAjEYj06ZNY9CgQSxdupSPPvqIqKgoJElClmW0Wi0zZ86kf//+jB07FgCbzUZGRgYdO3YEICwsjJUrV7JlyxaWL1+OLMuoVCpmzpxZ7x81wXvkFlWSU1iBoij46rUYAn0xV9v5ZOtx9h7JA5xXbXcPbU+n1qGu/WRZobTKQqvoIIID9K7tYUG+XN8mlK9/yKKs0sqhE4W0irnykRMXeJzgFi9ezKZNmzCZTIwYMYLk5GT69u3ranZef/31TJo0qc4EZzQaefPNN5kzZw4Wi4XQ0FCSk5MB5zjMBQsW0LVr1ys4pealKAoWa9POiqvXqT2a1mjp0qVs2rSJJUuW0KFDBwCOHTvG1KlTXfcphw8fXmOg/ddff8306dPZsWMHGzZsAJyzB48YMcL1GpzTLj3//POkpKSQkJDA0aNHmTRpEtu3byc4OLgRzlZoqfKLq8jKLSPIX49GrUJRFA6dKGT9zpOUVTrHkw7sHsPtA1rjo/s1tdjsMmVVFmLC/IkIqXnLyVevITLUj96dItlx4CwZZ0sbNWaPE1xaWhozZsxgxIgR+Pj4XPJ+TEwMb7zxRr119O/fn/Xr11+y/eJfpIu11AHhiqIwZ1kqRzMvXSymMXVqZeTV6QPdSnJFRUWsWLGCNWvWuJIbQMeOHZk/f36dU5T379+f6upqsrOz672XqlKpeOmll0hISACgffv2SJJEUVGRSHBezFRezS+5ZQQZ9KhVEkczi9n0bSZn8p3TbUWE+DJ+eCJtYn+9+lIUhYoqG7Ki0Co6iPBg31q/w+EhfvTrHkWAQUe/rtGXvH8lPE5wMTEx3HHHHZdsnzlzJkuWLCEkJKTWm/1C8zh48CAGg6HWq+Bhw4YB1JgTDpxfxDVr1hAeHk779u3rrT86OrrGU/C33nqL6667jjZtGvfeidByVFRZOXm2BIOvltO5pXz130wyc8sA0GlVDO55HcN7x7uGX8GvV23B/j7ERwW47sPVxt9XS2igH51aQZC/vs5yDeFWgsvNzWXz5s2A8yrrt1/m8vJyUlNTGzWwa4EkSbw6fWCLaqLW1lHynnvuobKykurqajp27Ei7du3Ytm0bhw8fBpwPehISElixYkWtV+W1sVqtLFiwgL179/KPf/xDzAzspWx2mVN5JejUatbvPOm6z6ZVqxjQPYZhva6r0a1DURTKzTYUWaF1dBBhwb6oVJf/bkQa/Sg8Pxi/MbmV4CIjIzl48CAmkwm73c727dtrvK/T6XjhhRcaPbhrgSRJ+NTz16m5devWjfLyco4dO+Z6OPDxxx8DsG7dOteaD7+9B+eJc+fOMX36dAwGAykpKY0ynY7Q8jhkmXOlZhRFzcbU03yflodKgv7dYhjeO/6Sqy2r3UF5lRVjgA/XRQZ49HsRaNARHOjeH1dPuBWBSqXi7bffBuCll17i+eefb/RAhMYRGRnJgw8+yKxZs1i8eLHrPlxxcTF79uxBrb6y+e8rKir44x//yMCBA3nmmWdq9F8UvEul2Y7Kx87mb8/yfVoekgR//N9O3NChZncsRVEoq7IiAW1jgzAGunfVdjFJkogN88dia9zWkNsp9sIVwdixYzly5EitZcSEki3DzJkzad++PfPnz6e01PlUSqVSMXToUJ599ln++c9/NrjulJQUTp8+jU6nq7EozrX29Fuon6JASYWFb45m8+3hAiQJJt3asUZysztkzBY7NrtMSKDeedWma3hr5uLuI43F7Rl9e/bsyYEDB1zNnksqkqQWuQq7mNFX8GZNNaOv2WJj/Z5CDmdVIwETRyTSu3MUNrtMlcWGLCto1CpCAvQEB/gQ5K9rkfdh3U63Bw4cAJxXcoIgeLcDJ8s5nOUckTD+lg707hyF1e6g0mwnJsyPQH89fj5a1B42RZub2wmurmbpBZIk0blz5ysOSBCEq0uWFQ6crARgeK8Y+nSJxu6Qqaiy0jYumNAg98aHtwRuJ7g777yz3vdbahNVEATP5BRVk2uyATCga4RzmFWFhfjowGsquYGHDxkEQfB+3x93duJNiNAT4q+juNJCdJiBqCZcoLmpePwUta6mqmiiCsK1r9rqIO2Ms8Pt9fG+mK12jEEBxEYEtMiHCJfjdoK75557OHDgQJ1NVdFEFYRrX1pWBVUWBV+dRIxRi16nJiEqsMU/TKiLeIoqCALg7LB74EQFAB3jfFEUhUA/58wh16oG9crLzMxk06ZNFBYWEhcXx+23305kZGRjxyY0QGJiIu3atUOj0SDLMg6Hg1tuuYXHHnsMjUbD2bNnGTZsWI3+jFarlWHDhvGXv/zF1QzZunUrK1euxGQyodPpiIyM5JFHHqFnz55X69SEJlZYWs3pAufKVokxPmi1KnTaKxv5crV5nOC2bdvGE088QZ8+fYiKimL37t0sW7aMFStW0Lt376aIsUVTFAXF1vjLnV1M0uo9uv+xcuVKoqKcS7KZTCamTZtGZWUlzz33HABqtbrG1FTl5eWMGTOG6OhoJk2axMcff8z777/P4sWLXaMTUlNTefjhh1m8eDEDBw5sxLMTWopvDuWiKBAVrMXfTyLEX+fxkKuWxuME9+qrr7J06dIaK1199dVXLFq0qNY53ryZoijkrH4Wy9njTXocfVxHYiYvaNBN3pCQEObMmcPkyZOZOXNmrWUCAgK4/vrrycjIwGq1snjxYt56660aQ68GDhzIww8/zKuvvioSnBeqttrZ85NzfeJO1/kgAb4+LWcSiYbyuHFdVlbGTTfdVGPbLbfcQmZmZmPFdI1p+X/hOnbsiM1m49SpU7W+n5GRwd69e+nfvz/p6elUVFTU2hTt168f6enplJWVNXXIQjNyOGR27j9LYUk1GrVEfLiGAIMWtRdMpOBxiv7f//1f3n//faZOneratnbtWoYMGdKogV0LJEkiZvKCFtdEvWT/8/teWKHM4XCQlJQEgCzL6PV6pk+fzvDhw13dgOx2+yX1WK3WBscgtEyKonC2sII9P+YA0D7GB61ahb9v4w98vxrcTnCjR48GnF/yTz75hI8//piYmBgKCgr45Zdf6NGjR1PF2KJJkoSka/x5rBrTzz//jK+vL/Hx8RQWFl5yD+5i7dq1IygoiO+//57hw4fXeO+HH36gffv2BAY27tJuwtVTaDLz88lzpJ0uAqBjjB4fnRof3bX9cOECtxPc/fff35RxCE0kPz+f119/nfvuuw+9/vJ/lfV6PbNmzWLhwoVERka67sN98803vPvuu7z22mtNHbLQTMoqrRw5XUTKtnSsdpm2sQGEBqobfdrwq8ntBHfx3F+1qWsxE6H5TZ06FY1G41oScNSoUTzwwANu7z9u3DjCwsJ45ZVXMJlMOBwO4uLieOedd+jVq1cTRi40F7tdJuNsMf/6+gSmcgthQT5Mvq0dOdln610/4Vrj8ZlkZmayfPly8vPzkWUZcK6fmZmZecliJkLzO368/ie6cXFxpKWlXbaeIUOG/C7vq/5elFRaWL/rJFn55fjqNTyQ1BW9RibIC7qGXMzjxyTPPvssxcXFhIeHoygKN9xwA1lZWUyaNKkp4hMEoZHJCnyemsXhjCJUKokpozoTGuSDAvj5aK92eI3K4wR35MgRlixZwgMPPIBWq+WJJ57grbfe4ptvvmmK+ARBaGSHTpbx9f5cAMYNbU/buGBKKy0E+evQaa79riEX8/hsAgMDMRgMJCQkkJ6eDkCvXr345ZdfGj04QRAa3+4jznU6BveMo0+XKErKq4k0Ggi4aPk/b+Fxgmvfvj1///vfUavVBAUFsX//fo4cOeLxak27d+8mKSmJ2267jSlTplBQUHBJmfT0dO69916SkpIYOXIkf//73z0NVxCEi1SYHeSXOCezHPo/11FaacUY5Etc5LU5HdLleJzgZs2axWeffUZeXh6PPvookydPZty4cfzpT39yu47i4mKeeuopkpOT2bx5M0OGDGHu3LmXlHvssceYOHEiGzZsYM2aNaSkpLBr1y5PQxYE4bz0bOdcb7Fhfsgo+PtqrunpkC7H46eoHTt2ZNOmTYDziVyvXr2oqKi4ZLX7+qSmppKYmEhiYiIAEyZMIDk5mcLCQsLDwwHnk9mpU6cyYsQIAIKCgkhISCA7O7vWOgsKCigsLLxke0ZGhkfnJwje7EKCax8XiFajpnVsMFovu+92sQZ1ePnuu+/44osvKCwsJDY2lrvuusuj/fPy8oiOjna91ul0hISEkJub60pwWq2Wu+++21Vm165dHDhwgHnz5tVa59q1a1m2bFkDzkYQfh8URSEj1wxAu7gA2sQEor/Gp0O6HI8T3Nq1a0lOTmbUqFH06NGD7Oxs/vjHP/Lyyy9z6623ulWHoii1tvfrWiU9JSWFxYsXs3TpUmJiYmotM378eIYOHXrJ9oyMDGbNmuVWXILgzfJMViqrZTQqaBMb4HVdQmrjcYJ75513WLVqFd27d3dtu+OOO3jxxRfdTnAxMTF89913rtdWqxWTyXRJ8rLb7cyfP5///ve/rF692tWkrU1ERAQRERF1vi8Iv3fp2c6rt2ijlgBf7+rQWxePG98Oh4NOnTrV2NajR49an4LWZcCAAaSlpbm6mXz66ad0794do9FYo9zs2bM5efIkn332Wb3JTRCEy3M9YAjVedVwrPp4fJZ33303r732GrNnz0ar1WK323n77bcvO1b1YkajkTfffJM5c+ZgsVgIDQ0lOTkZgKSkJBYscE7u+OWXXxIfH8/kyZNd+06YMIGJEyd6GrYg/K5ZbA6yCpwr1ceFadFqvffBwsXcTnA33HADkiShKApms5mUlBRCQ0MxmUyYzWZiYmJcU2K7o3///rXOAHzxND6XG1cpCIJ7jv9SikMGg4+K6BC9V0xm6Q63E9yKFSuaMg5BEJrQjyed873FGrUE+DXuwwVFkZGklpkw3U5wN954o+tnh8PBzz//TE5ODuHh4fTs2dPjkQyCIDSfn0+ZAGfztLFWylLsNmylhdhLClDpfVH5B6PWG1DpfJDULeMen8dR5Obm8tBDD3HmzBnCw8MpKCggMjKSVatWERsb2xQxCoJwBc6VmMk553zAEB+uv+IEp8gO7OVF2M7losh21D4GFIcN+7lsbChIkgqV3g+V3g9J54NKo0PSaJF0Ps1+pedxgnv55Ze54YYb+Ne//oVOp6O6uppFixaxYMECli9f3hQxCoJwBQ4ed/ZwCAvUEBbkwWy9ioxsrcZRCYosO1/LDhwl+cg2CyqtHktOBrLVgi6yFdqQKCSVCkV2oNht2MuLURwOJElBAdR6A9rQGFS+zTfu1eME98MPP7Bz5050OufMAz4+PsydO5ebb7650YMTBOHK7T/mTHDXhWk96B6iYCsvxu6nxlJqQUECRUGSQJFUmE//RMXhb5DN5a49JJ0PuogE9JGtUQcYUfv4o/L1R/LxR633Q7FVU52djto3EK0x2vleEyc6jxOcVqulrKzMNaQKnAsHX1ixSRCElqPSbONg+oUEp3N7aJa9shSpJBfZV8EuOUCWUWQH1sIsKtL2oFicTV61IQhNUATWgl9QrNVYzh6vdZ1gSavHv8tNGK6/GdlWTfXZ46gDQtCHxyNpmm5EhccJ7rbbbuPxxx9n1qxZxMTEcPbsWRYvXuz2KAZBEJqeoigUlZr56r+ZVFXb8dGpSYjwcWv0gmypQnMiFb+sfVQegspayqgDwwjoNgS/tjcgqTUosgNbcS7W/EysBb/gMJcjV1cgmyuQLVUoNgvlh7ZRefw7Am8YgW/7XjgqS7HYT6KLaoNK2zQL3Xic4J588kleeOEF7r33XhwOBzqdjqSkJJ588smmiE8QBA/ZHTL7jxXw6dfppJ0uBqBTQhCBhstfKSl2K1LGXgxZ+wCQDCGoNRqQVEgqNSq9H36JffBt1RXpfF86xW4DtRpdWBy6sDjoMrBmnbKD6l+OULpvE47yIkr+u46KI6kE3TgKKTQWS84J9NFtUekavxXocYLbtWsXL730EgsWLKC0tJSwsDCvnChPEK5FFpvMp9tPseWHHGx2GUmCAd1iGNYzApWtpN59FdmBkvUz/hm7nXXFdiV86L0E1JIXFUVBtlQhW6uRtHqw2FAUkNQaVFp9jWanpFLj27obPvGdqTy+l/JD27CXFlC0dRX+3YZg6DQAS/YJ9DHtUOn9GvPj8DzBPf/88+zZswetVlvjPpwgCFeX3aHw3qYccoutALSKDuSuIe0xBvqgV1uxVF7aRUNRZBSbFcVuxXrmGIajm5EUGVtYG8xt+oPsQLY7nN07JBVIIFvMKHYrat9AfCISUPkGoNityBYzclUZjsoSHJZKOP9gAklCQgFUGDr2w69dT8oObKEybQ8VP+3AVpRNcL8xVJ89gT62HWofQ6N9Jh4nuD59+pCSkkJSUhL+/v6NFoggCFdm77FScout+OjU/GFQO3p0CKey2oZDUQg06Cm86Gaao7qS6vI85/0xRcZRWUbl16tROWzYA6OoTByKYnWOXUWWkR1WUGRQFFR+AWgjW9d4Cipp9c77aP7BKEocisPuTG7n/ymKjKPsHLbSQiSVhqA+o9GFx1OS+i8s2emc2/w3ggfciSY49OomuGPHjrF161YWLFiAj49PjebpgQMHGi0wQRDcZ7Y42P5jCQCjB1xHYkIw1VY7cZEBhAf7Ulnxa3cOxWHDcvoQclkejopi7CWFWIuyUaxmHH4hVHa+DdlqRR0QjCY4HN/gEOd+ssOZsFTqem9LSZIKSXPpAjZqHwOaoHCspjwc5cXoYzsQPno6RV9/gKO8mKJt7xM2chq60LhG+1w8TnALFy5stIMLgtA4dvxkwmyVCfFX071dCMYQAxFGv1q7hegOrMOcexTzb7ZLhhAqutyOwyGj9g9E4x9SY+SBpLryIV4qvR8+UW1wBEVgK8pB0VoJv/0RTLtTsGSnU52VRmC3xltw3KMEV15ejsFgoF27duj1TfNYVxAEz5gqbHx7tAyA3u0NGAN9uS4yoNayUkEGutyjIEno4xLRBkWgCY5EGxyJJSia4hPpqP0C0ASEAk338FDt648qth320kJs584SPPBubMXZ+La6vlGP43aC279/Pw8++CCVlZWEh4fz7rvv0qVLl0YNRhAEz/1nXzEOGWKMWtpG6fH3rePXWpHRH3YuGKVr35fQ/kkgO1AUZ0fearsDyS8ATYCRpkxuF0iSCm1wJGrfAKyFZ1ACjI0+SN/tka9LlixhxowZHDx4kLvuuos33nijUQMRBMFzp3PK+SnT+fSgdzs/woJ961zbRJ25H3XFORS1Dm3nm5GrK5z93lQaVD7+qAyBaAOMzqelzUil90Mf0w5dVBskrU+j1u12ujx69CgffvghAH/+85+55ZZbGjUQQRA8oygK/9xyEoC20TpaR/vWOdZUtprRHdsOgDnuBoIDjPjGta5xX626tBSkkiaPuzaSSo0uJKrR63U7VSuK4vrZYDBgt9sbPRhBENy35btfOJZViloFfdr5ERJQ99VP1fefo7KZcfgEYYnogMoQ1CgPDVo6t6/gLk5wgiBcPUczi3j/8zTSMp3DsLrE+3JdlH+dCzjbTHlYftoKQFX8/6AKDEVVSzcOb+RRgktLS3MlOofDUeM1IB46CEITOlNQzj8+P8IPafkAqFQSfTuH0fU6iQC/8wlLkXFUmLDaKpyjFBQZ07YPQHZgC4rFbmyNzjfwKp5F83I7wZnNZsaOHVtj28WvJUni6NGjjReZIAgAVFvtfLTpGF+knsIhOy8ourcP46YescSHa6ksKUQlgercaXQH/02JuZSS31YiSVTF/w+aoFD4nSw4Ax4kuGPHjjVlHIIg1OL7I3ks/+xHzpU6h021vy6YQT1jaR0TTEyYAbVSzdkK0B7Zgjbj2/NjPi+l6zgQJTS+0Qezt3QtY2UIQRBqyC+u4u8bfua7w3kABPnruLVPAj07RRITZiDQoEetkijO/AWfb/6GusxZzhLRHuOAsRhUChLn751LEmZdIJqy6qt4RleHSHCC0AJYbA7SThWx72g+B48XcKagAgBJgr7XR3FL7+to5VOOT9VR7D/lU1xaiK20gOqsNNR2G7JGT2WrPtgjEokKuw5ff4NrHQUUGZvVgVRZcJXPsvmJBCcIV0lVtY0f0vL55tBZDh4vxGaXa7zfJtKPsZ3stLZ9j7z1PcorSyivpR5bUAwVrfujMsai8wtCpfdD7VfzQYKqtLQJz6TluioJbvfu3bz++utYLBaio6N59dVXiYiIqFHGZDIxd+5csrKycDgczJo1i+HDh1+NcAXhilVb7BSXV2Mqs5BfXMXew7n8cDQfg1xOK3Uhw7UlBPrZCfdTCPGRCdLa0ZZlw0/VXOhxKqm1aIIjUPsFovILRG0Iwh4YRYndF22A0TnxpFBDsye44uJinnrqKVavXk1iYiKrV69m7ty5rFy5ska5efPm0aFDB959913Onj3L+PHj6dKlC9HR0c0dsiDUyeGQKau0UlpppbTCgqm0ClPhOcqKi6gymbCUm7BWVSE77KiQUUsKWux00Ji41b+QYFVVzQqt5/+dp9L7oY/tgE9cIj7XdULtH4ykdS6sLKk1lFVZ0Obk0BxjR69FzZ7gUlNTSUxMJDExEYAJEyaQnJxMYWGha4Zgu93Ojh072LTJOTA4Li6OgQMH8vnnn/Pggw82WWynDh/mXObJyxe8Vjo9/ybOWqN251zcPN/fdga/+Iner29dvq6a9Siu/yiuH5SLCzsnVDz/84Uy0oXy5ydbdP73wgSMsvM9Wa5Z5vzKUSjy+UHozu2KfH5fhw3JXo3aYUEtW9DIVjSKHa1kR4sDveQgXnIQ/9sTqm94pSShDgpHZ4xB5WNApfNF0vuh0vuiCYrA57pE5/J7Oh8k9aVzh6usMiK51a3ZE1xeXl6NqzCdTkdISAi5ubmuBGcymaiuriYq6texaVFRUeTm5tZZb0FBAYWFhZdsv9A3LyMjo964qquqyP/4r2ikayR5CVeFAtjP/7PUU8Ym6ZDVPqDzQaPTodGo0ajVSCoVkkqF2j8EfVRr9BGtnDPjqrWuRVxqyCkGiuuMp6Kigvz8fNcfBUmSqKqqumS27YvLuVOmvrpakjZt2tS7ZGmzJ7gLH/BvXTwDwsUfcF1lfmvt2rUsW7aszvdnzZrlaaiCILRw69atq3cEVbMnuJiYGL777jvXa6vVislkIiYmxrUtNDQUvV5PQUEBkZGRAOTn59OuXbs66x0/fjxDhw69ZHtZWRkZGRl07tz5spN0njt3jnXr1jF27FjCwsJc8b3//vv86U9/QqfTuVWmtrrcKdPcx2uJMYnjieNdXO5y2rRpU+/7zZ7gBgwYwIIFC0hPT6dDhw58+umndO/eHaPR6CqjVqsZNmwYH3/8MU888QTZ2dns3r2badOm1VlvRETEJU9iL+jXr59bseXm5mI0GunQoYOrGW2xWDAaja4E6U6Z2upyp0xzH68lxiSOJ453cbkr1eyD0oxGI2+++SZz5sxh5MiRbN68meTkZACSkpL4+eefAefyhBkZGYwaNYqpU6fy9NNP06pVq+YOVxCEa9hV6QfXv39/1q9ff8n2DRs2uH42Go313lNrLmq1mkGDBqFW1z13VmOVae7jtcSYxPHE8RqTGMlwGRqNhsGDBzdLmeY+XkuMSRxPHK8x/X7mTREE4XdHJDhBELyWSHCCIHgtkeAEQfBaIsEJguC1RIITBMFriQQnCILXEglOEASvJRLcRfz9/Rk0aFC908O4U6Yx62qJx2uJMYnj/T6O5ylJEUvWC4LgpcQVnCAIXkskOEEQvJZIcIIgeC0xm8h57ixl6I41a9bw8ccfI0kSvr6+PPvss3Tr1u2KYvvxxx+ZNGkS27Ztq7FOhSdOnDjB/PnzKS8vR6VS8cILL9CjR48G1bVt2zbeeustVCoV/v7+zJ8/n7Zt23pUx5IlSygsLGThwoWAc6qs9957D7vdTqdOnViwYIFbN5t/W8+yZcvYtGkTKpWK0NBQXnzxRVq3bt2gmC7YunUrTzzxBIcPH27w+e3bt4/k5GSqq6sxGAwsWrTIrbh+W88nn3zC6tWrUavVREVFsXDhwst+T+v6Tq5cuZJPP/0Uh8PBwIEDeeaZZ9BqL13Y5nJ1derUiZdffpm9e/ciSRIJCQnMmzfP7Vl5ExMTadeuHRrNr+koOjqad999163966UISlFRkXLjjTcqx44dUxRFUT744APl/vvv97ie/fv3K4MHD1aKiooURVGU7du3KwMGDFBkWW5wbOfOnVOSkpKUDh06KLm5uQ2qw2w2KwMHDlT+85//KIqiKDt27FAGDx7coLjMZrPStWtX5cSJE4qiKMrq1auVSZMmub3/mTNnlEceeUTp1q2b8swzzyiKoijp6elKv379lLy8PEVRFOXll19Wnn/+eY/r2bhxozJ27FilsrJSURRF+fDDD5W77rqrQTFdcPLkSWXo0KFKp06dGnx+eXl5Su/evZVDhw4piqIoH330kXLPPfd4XE9WVpbSs2dPpbCwUFEU5+c0e/bseuup6zu5c+dO5bbbblPKysoUu92uzJgxQ1mxYkWD6nrnnXeUhx56SLFarYqiKMorr7yiPP744/XWdbEr+W5fjmiiUvtShnv37q11la76BAUF8dJLL7mmX+/WrRtFRUWYzeYGxWW323nyySeveMGc1NRUwsPDGTFiBACDBg1i+fLllyzz5w6Hw4EkSZSeXym9qqoKH5/61sWrae3atfTv358pU6a4tm3bto1Bgwa51t+YNGkSn3/+ObIs11VNrfUkJCTw3HPP4efnBzg//+zs7AbFBM5VpmbNmsWzzz57Ree3efNm+vbtS/fu3QEYN24c8+bN87geWZZxOBxUVVWhKIpbn31d38mtW7dy++23ExAQgFqtZuLEibVOQutOXZ06deLJJ590Xf117drVrc+9OYgmKu4tZeiOtm3buppqsiyzaNEiBg8e7PqF81RycjJ9+vRhwIABDdr/gtOnTxMREcFzzz1HWloa/v7+PPXUU/WuUlYXg8HAvHnzuO+++zAajVgsFlavXu32/n/5y18AWLp0qWtbbm5ujc8/KiqKqqoqSkpKaqzVcbl6Lr4VYLFYeO211xg5cmSDYgKYO3cu9913Hx06dLhsHfXVdfr0aQwGA08++aTr/8XTTz/tcT0JCQlMmzaNkSNHEhQUhF6vZ82aNfXWU9d3Mjc3lxtuuMFV7nLLctZX18UTV5aUlPB///d/jBs3rt66fmvq1Kk1mqjJycmuC44rIa7gcG8pQ09UVFQwffp0srOzefXVVxtUxxdffEFWVhYPP/xwg/a/mN1uZ8+ePYwZM4Z169YxdepUHnzwQSoqKjyu6/jx47z99tts3LiRb775hueee44HHniAqqqqy+9cj9o+/9q2uaOgoID77ruPgIAAZs+e3aA6VqxYQVhYGElJSQ3a/2IXFjJ/9NFHWb9+PUOGDKl3AaW6pKam8sUXX/D111+TmprKhAkTmDZtmltX4rV9J3/7+br7edf1/T516hT33HMPvXv35r777vPgzGDlypVs2LDB9a8xkhuIBAc4lzLMz893va5tKUN3nT59mrvuugt/f38++OADAgMDGxTTZ599RlZWFmPGjHH9kk2dOpV9+/Z5XFdkZCStWrWiV69egLOJqtFoOHXqlMd1paam0rVrV9dybaNHj8bhcFx2Ye36/Pbzz8/Px2AwEBQU5HFdP/74I3feeSe9evVi2bJlriXpPLV+/Xr27dtHUlISDz74IA6Hg6SkJLKysjyuKzIykh49eriufsaOHUtmZibFxXUv6Fyb7du3c/PNNxMZGYkkSUyePJm0tDRMJlO9+9X2naztM3fn+17X93vHjh1MnDiRCRMm8Ne//rXBf5wam0hwOJcyTEtLIz09HaDWpQzdkZOTw6RJkxg3bhzJycmXXYe1Pv/4xz/46quvXH/RwPlX7kKS8sTNN99Mbm4uhw4dAmD//v1YrdbLrilZmy5durB//37y8vIA+OGHH7Db7W4/qazNsGHD2LVrl+sX7p///CfDhw/3+Ar6yJEj3H///Tz99NMNboJfsHnzZj7//HPX0121Ws2GDRuIj4/3uK5bbrmFgwcPkpmZCcCWLVuIj48nODjYo3q6dOnC7t27KS8vB+A///kPCQkJ9X5P6/pO3nLLLXz55ZeUlZUhyzKffPKJ6x6tp3Xt3LmT2bNn8/bbbzN58mSPzqmpiXtw1FzK0GKxEBoa6lrK0BMrV66krKyMjRs3snHjRtf29957z3UD/WoICwtjxYoVLFq0iKqqKtRqNUuXLm3QmL++ffvy6KOPMmXKFLRaLX5+fixfvvyKxg+2b9+e2bNn8+c//xmbzUbr1q155ZVXPK5n6dKlyLLMe++9x3vvvefafvFqbVdDx44dWbhwIY8//jh2ux1/f3+WLl3qcQIeO3Ysubm53HXXXej1eoxGI8uXL693n/q+k3feeScTJ07EbrfTs2fPy94Oqasus9mMJEksWrTItS0iIoK//e1vHp1fUxBjUQVB8FqiiSoIgtcSCU4QBK8lEpwgCF5LJDhBELyWSHCCIHgtkeAEQfBaIsEJguC1RIITWpTExER+/vnnqx2G4CVEghMEwWuJBCdcM3Jycnj00UcZPHgw3bp1Y8yYMRw4cABwTkTw+uuvu8pemKV2z549gHN88W233UavXr249957OXnypKtsYmIiL730EjfeeGON4UbCtU8kOOGa8fzzzxMZGcmWLVv44Ycf6Ny5syup3XHHHXz11Veust999x0qlYp+/fqxZcsWlixZwuLFi/n2228ZMWIE999/f42JSE0mE6mpqTz22GPNfl5C0xEJTrhmLFy4kKeeegqA7OxsAgMDXTOQjBgxApPJxMGDBwHYuHEjo0ePRqVSkZKSwqRJk+jSpQtarZZ7770XPz8/du7c6ap75MiR6HQ6AgICmv28hKYjZhMRrhmnT5/mtddeIycnh7Zt22IwGFyTPfr6+nLrrbfy5Zdf0qlTJ7Zu3conn3wCOJu2+/btY9WqVa667HY7OTk5rtcNWWBIaPlEghOuCTabjenTp/PCCy+4JgBdu3ZtjXtpSUlJzJo1i969e5OQkOCaajwyMpJJkyYxadIkV9nMzMwa09G3lAkahcYlmqhCi1NcXExeXp7rX0FBAVarlerqatciK8ePH2fVqlVYrVbXfn369EGj0bB8+fIaU43feeedrFq1ivT0dBRFYdu2bYwaNYrTp083+7kJzUtcwQktzoMPPljjdXBwMHv37mXevHksWrSIuXPnEhsby7hx43jjjTcoLi7GaDSiUqkYPXo0q1atYtSoUa79R40aRXl5OTNmzHBNzZ2cnMz111/f3KcmNDMx4aXgVT755BN27NjBihUrrnYoQgsgmqiCVyguLubIkSN88MEHTJgw4WqHI7QQIsEJXuHgwYPcc8899OrViyFDhlztcIQWQjRRBUHwWuIKThAEryUSnCAIXkskOEEQvJZIcIIgeC2R4ARB8FoiwQmC4LVEghMEwWuJBCcIgtf6f36ruv6XBwyjAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 400.875x150 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "data = []\n",
    "for layer_idx in range(all_gpt2_prob.shape[0]):\n",
    "    for prob in all_gpt2_prob[layer_idx]:\n",
    "        data.append(\n",
    "            {\n",
    "                \"Layer\": layer_idx,\n",
    "                \"Model\": \"GPT2\",\n",
    "                \"Probability\": prob.item(),\n",
    "            }\n",
    "        )\n",
    "\n",
    "for layer_idx in range(all_dpo_prob.shape[0]):\n",
    "    for prob in all_dpo_prob[layer_idx]:\n",
    "        data.append(\n",
    "            {\n",
    "                \"Layer\": layer_idx,\n",
    "                \"Model\": \"DPO\",\n",
    "                \"Probability\": prob.item(),\n",
    "            }\n",
    "        )\n",
    "\n",
    "data = pd.DataFrame(data)\n",
    "\n",
    "sns.set_theme(context=\"paper\", style=\"ticks\", rc={\"lines.linewidth\": 1.5})\n",
    "fig = sns.relplot(\n",
    "    data=data,\n",
    "    x=\"Layer\",\n",
    "    y=\"Probability\",\n",
    "    hue=\"Model\",\n",
    "    hue_order=[\"GPT2\", \"DPO\"],\n",
    "    kind=\"line\",\n",
    "    height=1.5,\n",
    "    aspect=3.25 / 1.5,\n",
    ")\n",
    "\n",
    "\n",
    "major_tick_locs, major_labels = plt.xticks()\n",
    "minor_tick_locs, minor_labels = plt.xticks(minor=True)\n",
    "\n",
    "fig.ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n",
    "fig.ax.xaxis.set_major_formatter(ticker.ScalarFormatter())\n",
    "major_tick_locs, major_labels = plt.xticks()\n",
    "new_tick_locs = [\n",
    "    x for x in major_tick_locs if x >= 0 and x <= 48 and x % 2 == 0\n",
    "]\n",
    "new_minor_tick_locs = [\n",
    "    x for x in major_tick_locs if x >= 0 and x <= 48 and x % 2 != 0\n",
    "]\n",
    "\n",
    "\n",
    "major_labels = [x if x % 2 == 0 else \"\" for x in range(24)] + [\"F\"]\n",
    "\n",
    "plt.xticks(ticks=new_tick_locs, labels=major_labels)\n",
    "plt.xticks(\n",
    "    ticks=new_minor_tick_locs,\n",
    "    labels=[\"\" for _ in new_minor_tick_locs],\n",
    "    minor=True,\n",
    ")\n",
    "\n",
    "fig.ax.tick_params(axis=\"x\", which=\"major\", length=10)\n",
    "fig.ax.tick_params(axis=\"x\", which=\"both\", color=\"grey\")\n",
    "fig.ax.set_ylim(ymin=0)\n",
    "fig.ax.set_ylim(ymax=0.48)\n",
    "fig.ax.fill_betweenx([0, 0.48], 37, 38, alpha=0.35, facecolor=\"grey\")\n",
    "fig.ax.fill_betweenx([0, 0.48], 39, 40, alpha=0.35, facecolor=\"grey\")\n",
    "fig.ax.fill_betweenx([0, 0.48], 41, 42, alpha=0.35, facecolor=\"grey\")\n",
    "sns.move_legend(fig, \"upper left\", bbox_to_anchor=(0.22, 1))\n",
    "\n",
    "plt.savefig(\"logitlens.pdf\", bbox_inches=\"tight\", dpi=1200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77b8e31f",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
